{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "10679f41",
   "metadata": {},
   "source": [
    "### 7.5 卷积神经网络\n",
    "\n",
    "卷积神经网络和多层感知机这类的深度模型相比，最大的特点就在于卷积层。使用卷积层代替全连接层的主要优势在于可以更好地处理空间信息。在图像分类中，使用全连接层可能会丢失空间信息。比如，在一张图像中，如果绿色部分占据较大面积，而蓝色部分占据较小面积，那么使用全连接层可能会把这两种颜色视为同等重要。而使用卷积层，可以通过卷积核在输入图像上滑动的方式，提取局部信息。这样，卷积层可以根据局部信息更好地学习特征，并保留图像的空间信息。此外，使用卷积层还可以减少网络参数的数量。在卷积层中，一个参数就是卷积核的权值，而在全连接层中，每个节点都需要一个权值。因此，使用卷积层可以大幅减少网络参数的数量，使得网络更加简洁。\n",
    "\n",
    "举个例子，假设输入是一张 $28\\times28$ 的图像，输出是 10 个类别。如果使用全连接层，那么第一层就需要 $28\\times28\\times10=7840$ 个参数，而使用卷积层，只需要 $5\\times5\\times10=250$ 个参数。可以看出，使用卷积层的参数数量大大减少，对于较小的数据集来说，使用卷积层可以避免过拟合的问题。\n",
    "\n",
    "经过前面小节的学习，想必你已经充分理解为什么使用卷积层可以更好地处理空间信息，减少网络参数的数量，使得网络更容易训练。恰恰是这些优势使得，在图像分类等任务中，使用卷积层成为非常常见的做法。本节咱们就来学习最经典的深度卷积神经网络，由Yann LeCun在1998年提出的LeNet。它是为了解决手写数字识别问题而设计的，并且在当时取得了很好的成功。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a2e9b06",
   "metadata": {},
   "source": [
    "### 7.5.1 LeNet 模型结构\n",
    "LeNet的结构如下图所示：\n",
    "\n",
    "<img src=\"../images/7-5-1.png\" width=\"90%\"></img>\n",
    "\n",
    "- 输入层：LeNet的输入层接受28x28像素的灰度图像。\n",
    "- 卷积层1：这一层包含6个卷积核，每个卷积核的大小为5x5，卷积步长为1。该层使用Sigmoid激活函数。\n",
    "- 池化层1：这一层使用2x2的最大池化窗口，步长为2。这一层的作用是降低图像尺寸，并保留最重要的特征。\n",
    "- 卷积层2：这一层包含16个卷积核，每个卷积核的大小为5x5，卷积步长为1。该层使用Sigmoid激活函数。\n",
    "- 池化层2：这一层使用2x2的最大池化窗口，步长为2。这一层的作用与池化层1相同。\n",
    "- 全连接层：这一层包含120个节点，使用Sigmoid激活函数。\n",
    "- 全连接层：这一层包含84个节点，使用Sigmoid激活函数。\n",
    "- 输出层：这一层包含10个节点，对应0~9的十个数字。\n",
    "\n",
    "**这里为加速收敛将Sigmoid激活函数替换为ReLU函数。**\n",
    "\n",
    "下面是LeNet模型的代码实现："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d87c584f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class LeNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet, self).__init__()\n",
    "        # 卷积层1：输入1个通道，输出6个通道，卷积核大小为5x5\n",
    "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
    "        # 卷积层2：输入6个通道，输出16个通道，卷积核大小为5x5\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        # 全连接层1：输入16x5x5=400个节点，输出120个节点\n",
    "        self.fc1 = nn.Linear(16 * 4 * 4, 120)\n",
    "        # 全连接层2：输入120个节点，输出84个节点\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        # 输出层：输入84个节点，输出10个节点\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 使用Sigmoid激活函数，并进行最大池化\n",
    "        x = torch.relu(self.conv1(x))\n",
    "        x = nn.functional.max_pool2d(x, 2)\n",
    "        # 使用Sigmoid激活函数，并进行最大池化\n",
    "        x = torch.relu(self.conv2(x))\n",
    "        x = nn.functional.max_pool2d(x, 2)\n",
    "        # 将多维张量展平为一维张量\n",
    "        x = x.view(-1, 16 * 4 * 4)\n",
    "        # 全连接层\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        # 全连接层\n",
    "        x = torch.relu(self.fc2(x))\n",
    "        # 全连接层\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdbf80ac",
   "metadata": {},
   "source": [
    "### 7.5.2. 使用MNIST数据集的模型训练\n",
    "让我们用MNIST数据集实操一下LeNet的模型训练流程。主要步骤如下：\n",
    "\n",
    "* 下载MNIST数据集并将其加载到程序中。\n",
    "* 定义LeNet模型。\n",
    "* 定义损失函数和优化器。使用交叉熵损失函数和随机梯度下降（SGD）优化器。\n",
    "* 遍历数据集并使用训练数据训练LeNet模型。在训练过程中，需要计算损失值，并使用优化器更新模型的参数。\n",
    "* 在训练过程结束后，使用测试数据评估LeNet模型的准确性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "04cbc625",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0 Loss: 2.7312533989004866 Acc: 0.2926\n",
      "Epoch: 2 Loss: 2.0932299511431336 Acc: 0.9009 \n",
      "Epoch: 4 Loss: 1.669355056561876 Acc: 0.9475  \n",
      "Epoch: 6 Loss: 1.4936095757677041 Acc: 0.9658 \n",
      "Epoch: 8 Loss: 1.3853249114034598 Acc: 0.9683 \n",
      "100%|██████████| 10/10 [02:19<00:00, 13.96s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAAsTAAALEwEAmpwYAAAljklEQVR4nO3deXxU5b3H8c8vM1nIShLCkrAEEQVEAgoqWBHrjoqoteqttqLVLlqrttdrtdutvVevtrbaWpVbN2rd6nLFKlqt2uBWAQ0oBCpKgIQAIQsJCVnnuX/MEJIQSIBJTmbyfb9e85qzPOfMLwfy5eGZs5hzDhERiXwxXhcgIiLhoUAXEYkSCnQRkSihQBcRiRIKdBGRKOH36oMHDRrkcnNzvfp4EZGItGzZsm3OuazO1nkW6Lm5uSxdutSrjxcRiUhmtn5v6zTkIiISJRToIiJRQoEuIhIlPBtDF5Ho1tTURHFxMfX19V6XEpESEhIYPnw4sbGx3d5GgS4iPaK4uJiUlBRyc3MxM6/LiSjOOcrLyykuLmb06NHd3k5DLiLSI+rr68nMzFSYHwAzIzMzc7//d6NAF5EeozA/cAdy7CJuyGXt1hoWLi8lNzORUZlJjMpMJDMpTn9xRKTfi7hALyyt4fdvfkagzW3ck+P9jMpMDL2S2oX9kJQEYmIU9iL9UXJyMjt27PC6jF4TcYF+Tl42px0xhOLKnWwor6OovJb15XWsL69ldWkNr6/aQlPL7rSP98cwKjORkRmhoB+UxKiMRHIzk8gemIDfp1EnEYkOERfoAPF+H2OykhmTlbzHuuaWAKXb61nfGva7Ar+Od9aWUd8UaG3rjzFGZCQyMiOxXa9+VGYSIzIGEO/39eaPJSI9xDnHTTfdxKJFizAzfvzjH3PRRRdRWlrKRRddRHV1Nc3Nzdx///3MmDGDK6+8kqVLl2JmXHHFFdxwww1e/wjdEpGBvi9+XwwjMhIZkZHIl8YOarcuEHBsrWloDfmi8lrWVwR79x+tr6Smobm1rRlkpw3oMIyzO/QT46Lu0In0mP98aSWrNlWHdZ8TslP52TlHdKvt888/T0FBAcuXL2fbtm1MmzaNmTNn8sQTT3D66adz66230tLSQl1dHQUFBZSUlPDpp58CUFVVFda6e1K/SqWYGGNoWgJD0xI49pDMduucc1TUNrYGfNG2OjZUBEP/tZWbqahtbNc+KyWe6YdkcteFk9STF+nj3nnnHS655BJ8Ph9DhgzhxBNPZMmSJUybNo0rrriCpqYm5s6dy+TJkznkkEP44osv+N73vsdZZ53Faaed5nX53davAn1fzIzM5Hgyk+M5amT6Huur65vajdmv3bqDFz4uYUhqPLeeNcGDikUiR3d70r1t5syZ5Ofn8/LLL3P55Zdz44038vWvf53ly5fz2muv8cADD/DMM8/w8MMPe11qtyjQuyk1IZaJOWlMzElrXZYc7+d/F6/jhLFZzDys09sTi0gfcMIJJ/Dggw/yjW98g4qKCvLz87nrrrtYv349w4cP56qrrqKhoYGPPvqI2bNnExcXxwUXXMDhhx/OpZde6nX53aZAPwi3njWeD74o58ZnlvPq9ScwKDne65JEpBPnnXce77//Pnl5eZgZd955J0OHDuWxxx7jrrvuIjY2luTkZBYsWEBJSQnz5s0jEAieQHH77bd7XH33mXOu61Y9YOrUqS4aHnCxenM1c37/Ll86dBAPfWOqLnASCSksLGT8+PFelxHROjuGZrbMOTe1s/Y6CfsgjRuayq2zx/Pm6q0seH+vDxIREelxCvQw+Pr0UXx53GD+65VCCkvDe2qWiEh3KdDDwMy46yuTSBsQy3VPfkx9U4vXJYlIP6RAD5PM5Hh+fWEen23dwX+9XOh1OSLSDynQw2jmYVlcdcJo/vTBel5ftcXrckSkn1Ggh9m/nz6OiTmp3PTscrZU69FbItJ7FOhhFueP4Z6Lp1DfFODGZwoIBLw5LVRE+p8uA93MRpjZW2a2ysxWmtn3O2kzy8y2m1lB6PXTnik3MozJSubncybw7tpy5i/+wutyRKSHNTc3d92oF3Snh94M/MA5NwE4DrjGzDq7ecli59zk0OsXYa0yAn116ghmHzmUX722hhXFVV6XI9JvzZ07l6OPPpojjjiC+fPnA/Dqq69y1FFHkZeXx8knnwzAjh07mDdvHkceeSSTJk3iueeeA4IPydjl2Wef5fLLLwfg8ssv59vf/jbHHnssN910Ex9++CHTp09nypQpzJgxgzVr1gDQ0tLCD3/4QyZOnMikSZP43e9+x5tvvsncuXNb9/v6669z3nnnHfTP2uWl/865UqA0NF1jZoVADrDqoD89ipkZt583iYIN+Vz35Me8fN0JJMXrTgvSTy26GTZ/Et59Dj0Szryjy2YPP/wwGRkZ7Ny5k2nTpnHuuedy1VVXkZ+fz+jRo6moqADgtttuIy0tjU8+CdZZWVnZ5b6Li4t577338Pl8VFdXs3jxYvx+P2+88Qa33HILzz33HPPnz6eoqIiCggL8fj8VFRWkp6fz3e9+l7KyMrKysnjkkUe44oorDu54sJ9j6GaWC0wB/tnJ6ulmttzMFplZp7dWM7OrzWypmS0tKyvb/2ojTFpiLL+9eAobKur42cKVXpcj0i/de++95OXlcdxxx7Fx40bmz5/PzJkzGT16NAAZGRkAvPHGG1xzzTWt26Wn73nX1Y4uvPBCfL7g7bO3b9/OhRdeyMSJE7nhhhtYuXJl636/9a1v4ff7Wz/PzLjssst4/PHHqaqq4v333+fMM8886J+1211GM0sGngOud851vBzyI2CUc26Hmc0G/g8Y23Efzrn5wHwI3svlQIuOJMeMzuDakw7l3jfXMvOwLObkZXtdkkjv60ZPuie8/fbbvPHGG7z//vskJiYya9YsJk+ezOrVq7u9j7b3Z6qvb3/mWlJSUuv0T37yE0466SReeOEFioqKmDVr1j73O2/ePM455xwSEhK48MILWwP/YHSrh25msQTD/M/Ouec7rnfOVTvndoSmXwFizWxQx3b91XUnj+WokQO59YVP2FhR53U5Iv3G9u3bSU9PJzExkdWrV/PBBx9QX19Pfn4+69atA2gdcjn11FO57777WrfdNeQyZMgQCgsLCQQCvPDCC/v8rJycHAAeffTR1uWnnnoqDz74YOsXp7s+Lzs7m+zsbH75y18yb968sPy83TnLxYCHgELn3N17aTM01A4zOya03/KwVBgF/L7gqYzOwQ1PF9DcEuh6IxE5aGeccQbNzc2MHz+em2++meOOO46srCzmz5/P+eefT15eHhdddBEAP/7xj6msrGTixInk5eXx1ltvAXDHHXdw9tlnM2PGDIYNG7bXz7rpppv40Y9+xJQpU9qd9fLNb36TkSNHMmnSJPLy8njiiSda133ta19jxIgRYbsrZZe3zzWzLwGLgU+AXUl0CzASwDn3gJldC3yH4BkxO4EbnXPv7Wu/0XL73P3xYkEJ33+qgOtPGcv1pxzmdTkiPUq3z+3atddey5QpU7jyyis7Xb+/t8/tzlku7wD7vMm3c+73wO+72ld/d+7kHP7xrzLu/ftnHH/oIKblZnhdkoh45OijjyYpKYlf//rXYdunrhTtZb84dyLD0xO5/qkCtu9s8rocEfHIsmXLyM/PJz4+fE86U6D3suR4P/deMoUt1fXc8sInePXEKJHeoL/fB+5Ajp0C3QOTRwzkhlMP4+UVpTy7rNjrckR6REJCAuXl5Qr1A+Cco7y8nISEhP3aTpcueuTbJ45h8Wdl/GzhSqbmZjB6UFLXG4lEkOHDh1NcXEx/uIiwJyQkJDB8+PD92kYPifZQ6fadnHnPYkakJ/Lcd2YQ59d/mERk3/SQ6D5qWNoA7jh/Ep+UbOfXr6/xuhwRiXAKdI+dMXEo/3bsSB78xxe889k2r8sRkQimQO8DfnLWBMZkJXHjMwVU1DZ6XY6IRCgFeh8wIM7H7y45iqq6Jm56drnOChCRA6JA7yMmZKdy85njeKNwK49/sN7rckQkAinQ+5B5x+cy6/AsfvlyIWs213hdjohEGAV6H2Jm3PWVPFIS/Fz35MfUN7V4XZKIRBAFeh+TlRLPry7MY82WGu5Y1P2b8IuIKND7oFmHD+aK40fz6HtF/L1wi9fliEiEUKD3Uf9x5uGMH5bKvz+7gq3V9V1vICL9ngK9j4r3+/jdJZOpa2zmB39ZTiCgUxlFZN8U6H3YoYNT+OnZR7D4s2089M46r8sRkT5Ogd7HXXLMCE4/Ygh3vraaT0u2e12OiPRhCvQ+zsy44/xJZCbFc92TH1Pb0Nz1RiLSLynQI0B6Uhy/uWgy68pr+cVLq7wuR0T6KAV6hJg+JpPvzhrD00s38vKKUq/LEZE+SIEeQa4/5TDyRgzkR8+voKRqp9fliEgfo0CPILG+GO69eDItAccNTxXQolMZRaQNBXqEGZWZxG1zJ/JhUQX3vbXW63JEpA9RoEeg848aztzJ2dzz989Ytr7C63JEpI9QoEeoX8ydSPbABL7/VAHV9U1elyMifYACPUKlJsTy24umULq9nuue/JhKPbpOpN9ToEewo0el8/M5R/DOZ9s47bf5vL5Kd2YU6c8U6BHusuNG8eK1xzMoOZ6rFizlxqcL2F6nIRiR/kiBHgWOyE7jxWuO57qTx/Li8k2c9tt/8OZq9dZF+hsFepSI88dw46mH8eI1xzNwQBxXPLqUf//LcrbvVG9dpL/oMtDNbISZvWVmq8xspZl9v5M2Zmb3mtlaM1thZkf1TLnSlYk5aSz83vFce9KhPP9xCaf/Jp+312z1uiwR6QXd6aE3Az9wzk0AjgOuMbMJHdqcCYwNva4G7g9rlbJf4v0+fnj64Tz/nRmkJPi5/JEl3PzcCmp0eqNIVOsy0J1zpc65j0LTNUAhkNOh2bnAAhf0ATDQzIaFvVrZL3kjBvLS977Et08cwzNLN3L6b/JZ/FmZ12WJSA/ZrzF0M8sFpgD/7LAqB9jYZr6YPUNfPJAQ6+PmM8fx3HdmkBDn47KHPuSWFz5hh+6rLhJ1uh3oZpYMPAdc75yrPpAPM7OrzWypmS0tK1NPsTdNGZnOK9edwNUzD+HJDzdw+m/yeW/tNq/LEpEw6lagm1kswTD/s3Pu+U6alAAj2swPDy1rxzk33zk31Tk3NSsr60DqlYOQEOvjltnj+cu3phPnj+Hf/vhPfvrip3oKkkiU6M5ZLgY8BBQ65+7eS7OFwNdDZ7scB2x3zukpDH3U1NwMXrnuBK44fjR/+mA9Z9yTzwdflHtdlogcpO700I8HLgO+bGYFoddsM/u2mX071OYV4AtgLfC/wHd7plwJlwFxPn56zgSevno6MWZcPP8Dfr5wJXWN6q2LRCpzzpuHJEydOtUtXbrUk8+W9uoam7nz1TU8+l4RuZmJ3HVhHtNyM7wuS0Q6YWbLnHNTO1unK0WFxDg/P59zBE9edRzNAcdXH3yf2/66ivqmFq9LE5H9oECXVtPHZPLa9TO59NhRPPTOOmbfs5hl6yu9LktEukmBLu0kxfu5be5E/vzNY2loDnDhA+9x+yuF6q2LRAAFunTq+EMH8er1J3DRtJE8mP8FZ927mI83qLcu0pcp0GWvUhJiuf38I1lwxTHUNbZwwf3v8T+vrqahWb11kb5IgS5dmnlYFq/dMJOvHD2c+9/+nHN+9w4riqu8LktEOlCgS7ekJsRy51fyeGTeNLbvbOK8P7zHr15bo966SB+iQJf9ctLhg/nb9Scyd3IOv39rLef+/l0+LdnudVkiggJdDkBaYiy//moeD31jKuW1jcy9713ufv1fNDYHvC5NpF/TlaJyUKrqGvnPl1bxwsclZCbFMfvIYcyZnM3RI9OJiTGvyxOJOvu6UlSBLmGR/68ynl6ykTcKt9DQHGBYWgJnTxrGnLwcJuakErzHm4gcLAW69JodDc38vXALCws2kf9ZGU0tjtzMRM7Jy2ZOXjZjh6R4XaJIRFOgiyeq6hp5beVmFi7fxPuflxNwMG5oCufkZXP2pGGMykzyukSRiKNAF89traln0SebeWn5JpaG7g+TN2Ig50waxtmTshmaluBxhSKRQYEufUpxZR0vryjlpRWb+LSkGjM4JjeDc/KyOXPiUDKT470uUaTPUqBLn/V52Q7+uryUhctL+LysFl+Mcfyhg5iTl81pRwwhNSHW6xJF+hQFuvR5zjkKS2t4acUmXlq+ieLKncT5Yph1eBZzJmdz8rghDIjzeV2miOcU6BJRnHN8vLGKl5Zv4uUVpWytaSAxzscp44cwJy+bEw4bRLxf4S79kwJdIlZLwPHhugoWLt/Eok9LqaprIjXBzxkTh3JOXjbTD8nE79MFz9J/KNAlKjS1BHhn7TZeWr6Jv63cwo6GZgYlB69OPSdPV6dK/6BAl6hT39TC22u28tLy0tarU7PTEjg7L5uZY7M4atRAEuP8XpcpEnYKdIlqOxqaeWPVFl5avol//KuM5oDDH2MckZPGMbnpTMvNYFpuBulJcV6XKnLQFOjSb9TUN7FsfSVLiipYsq6SguKq1rtAjh2czLTRGRyTm8G00RnkDBzgcbUi+0+BLv1WfVMLn5Rs58N1FSwpqmBZUSU1Dc0A5AwcwLTc9NaQP3Rwsm4iJn3evgJdg4wS1RJifa1DLhA8a2b15mqWrKtgSVEl735ezv8VbAIgPTGWqbm7e/BHZKcSqzNoJIKohy79mnOO9eV1fFhU0dqLX19eB8CAWB9HjRrItFDITxmZroubxHMachHZD1ur6/mwqIIl6yr4sKiS1ZurcQ78McbEnDSOGZ0R6vWnMzBRX7RK71KgixyE7Tub+Gh9ZWvIryjeTmNL8IvWw4YkB3vwoZDP1het0sMU6CJhVN/UwvKNVSwpCvbgP1pfyY42X7QeMzqDqbnpTMxO4/ChKSTEaphGwkdfioqEUUKsj2MPyeTYQzIBaG4JsHpzTesY/OLPynjh4xIAYgxGD0pi/LBUxg9LZcKwVMYNS2FoaoLOqJGwUw9dJMycc2yoqGPVpmoKN9dQWFpNYWk1xZU7W9sMTIxl/NDUUNCnMH5YKocOTlZvXrqkHrpILzIzRmUmMSoziTOPHNa6vLq+idWlwYBfvbmaVaU1PPHheuqbguPxvhhjTNbu3vz4YamMH5pCVkq8evPSLV0Gupk9DJwNbHXOTexk/SzgRWBdaNHzzrlfhLFGkaiQmhDLMaODX6Du0hJwFJXXBkM+FPZL1lXwYujceIDMpLh2Pfnxw1IZk5VMnF/nyEt73emhPwr8HliwjzaLnXNnh6UikX4k2CtPZkxWMmdP2r28qq6Rwja9+cLSGh57f33rbQxifcHtJrTtzQ9L0eP7+rkuA905l29mub1Qi4iEDEyMY/qYTKaPyWxd1twSYN22WlaVBgN+9eZq3v18G8+HvoAFyEqJbw33CcNSOXxoCqMyknRBVD8RrjH06Wa2HNgE/NA5t7KzRmZ2NXA1wMiRI8P00SL9g98Xw9ghKYwdksK5k3cvr6htbP3idVev/pHPy1vPlQcYmprAqMzE0CuJ3Myk1vkUPbc1anTrLJdQD/2vexlDTwUCzrkdZjYbuMc5N7arfeosF5Ge09QS4POyHazZXMOG8jqKyuvYUFFLUXkdZTUN7dpmJsUxKjOR3MwkRobed80PTIzVF7J9TI+e5eKcq24z/YqZ/cHMBjnnth3svkXkwMT6Yhg3NJVxQ1P3WFfb0Mz6NgG/vryWom11/HNdBS8UlNC2j5eS4G8X8LsCPzczUWff9EEHHehmNhTY4pxzZnYMEAOUH3RlItIjkuL9TMhOZUL2nmFf39RCcWUdRdvqWF8RCvvyOj4p2c6iTzfTEtid9gNifa3DNh1798PSBuDT4wB7XXdOW3wSmAUMMrNi4GdALIBz7gHgK8B3zKwZ2Alc7Ly6WklEDkpCrI9DB6dw6OCUPdY1tQTYVLWztVe/PvT+eVktb60uazdmH+eLYXjGgNaAH5mRSM7AAeSkDyBn4ADSBmgopyfoSlEROWgtAcfm6vrWoC8qr2V9m15+XWNLu/ZJcT5y0geQPXBAu6DfNT04JUE9/L3QlaIi0qN8MdYayDPGtF/nnKO8tpFNVTspqdxJSVXoFZpevrGKyrqmdtv4Y4yhaQmdhn3OwOA/BLpNwp4U6CLSo8yMQcnxDEqOZ9LwgZ22qW1oDgZ+h7DfVLWTDz4vZ3N1PYEOgwmDkuNaw71j2A9P75/DOgp0EfFcUry/9Rz7zjS1BNhSXd8u6EuqdlJcuZN/banhrTVbW++J07rPOF8w7NsE/bC0BAanJJCVEs/glPioOy1TgS4ifV6sL4bh6YkMT0/sdL1zjoraRjZV1VNSVUdx5c7W6ZKqnawo3k5FbWMn+zWykuPJSoknq03Qt3tPTWBQchzx/r4/xKNAF5GIZ2ZkJseTmRzPkcPTOm1T19jMluoGymoa2FpTH3pvaH0vqdpJwcZKymsb6exckYGJsWQlxzM4NT70ntBhPp6s5ARSB/g96/Ur0EWkX0iM8zN6kJ/Rg5L22a65JUB5bWP74K9uoGzH7vdlGyrZWt1AQ3Ngj+3j/DGtvf7dvf32vf8RGYlkJIX/ebQKdBGRNvy+GIakJjAkNQHovLcPwWGemobmDoFfT9mOBspC8+vL61i6vnKP4Z6rZx7CLbPHh7/2sO9RRKQfMDNSE2JJTYhlTFbyPts2Ngcorw0Ff00DOek98zBxBbqISA+L88cwLG0Aw9J6Jsh30SNPRESihAJdRCRKKNBFRKKEAl1EJEoo0EVEooQCXUQkSijQRUSihAJdRCRKKNBFRKKEAl1EJEoo0EVEooQCXUQkSijQRUSihAJdRCRKKNBFRKKEAl1EJEoo0EVEooQCXUQkSijQRUSihAJdRCRKKNBFRKKEAl1EJEp0Gehm9rCZbTWzT/ey3szsXjNba2YrzOyo8JcpIiJd6U4P/VHgjH2sPxMYG3pdDdx/8GWJiMj+6jLQnXP5QMU+mpwLLHBBHwADzWxYuAoUEZHuCccYeg6wsc18cWiZiIj0In9vfpiZXU1wWIaRI0f25keLiPSuQAACzbtfrgUCLcHp2AGQkBb2jwxHoJcAI9rMDw8t24Nzbj4wH2Dq1KkuDJ8tEh0CAQg0BX/ZW5p2h0BLU2h5S5vpZmhp3st0x22bwQWC27vA7lBpnQ69t7bZNd3F8k73t7fPCYBzYAZY6J028zEd1tk+1oW2s5i9tO9inXOh2toE7R7B23a+pf27a+l8+R7rm/f95/2lG+CUn4f9r1E4An0hcK2ZPQUcC2x3zpWGYb8iB8e54C9WcwO0NO5+Ne+abgiGXnNDh+mm0Pyutg0dttu1nzZtO9vnHkG8j1B2AW+PlcWEXj6I8bWZ7u7y0LqYmDbTHZbD7n8UcME/n13vrpNlres6WYbb/Q/FXtex57JddcX4gy/bNe1rv9wXC/6E3fMd18f4Q/tqu77Nu/k6X75ru2F5PfLH2GWgm9mTwCxgkJkVAz8DYgGccw8ArwCzgbVAHTCvRyqV/iEQgMYaqK+GhhpoCL3Xb28z3dm6NvPNDbtDnDD/R9AXB7548MeFpkMvf3wwBHyh9wGJweUxvuB8TGwoKPxtpmPbB0i76di9b+sLrWud7riP2FDbtqHl2x3a7YJ213IL73EST3QZ6M65S7pY74BrwlaRRCbnoGlnm2Ct7iR4q7tYVxMM865YDMSnQHxq6JUCyYMhc0xw2j+gTeB2Eb5t2/liQ+vatmu7bayCT/q0Xv1SVPog56CxNhSuNe1Dtt1reyfLOrTtatwQIDYxGMIJqbtDOWVYaL5NQLdd37F9XJKCVaQTCvRI17QTKot2h2t9Z2Hc2bKa3T3i7ozf+geEArXNKz1393RccpvQTesklEPvPv2VE+kp+u2KVI11sPQhePceqC3be7u45A5BnAopQ9uEbMfXXpb7YnvvZxORA6JAjzRNO2HZo7D4bqjdCoecBFMuhQED9wzjuOTgl14i0i8o0CNFUz18tAAW/xp2bIbRM2HWAhg13evKRKSPUKD3dc0N8PGfgj3y6hIYdTxc8EcYfYLXlYlIH6NA76tamqDgz5D/K9i+EUYcB3P/AKNP1BkeItIpBXpf09IEy5+C/DuhagPkTIVz7oExX1aQi8g+KdD7ipZm+OQZ+MedULkOsqfAWXfDoacoyEWkWxToXgu0wKfPwdt3QMXnMHQSXPI0HHa6glxE9osC3SuBFlj5Avzjf2Dbv2DIkXDRn2HcWQpyETkgCvTeFghA4YvBHnnZasgaD19dAOPO2X1HOhGRA6BA7y2BAKz+azDIt66EQYfDVx6BCXMV5CISFgr0nuYcrFkEb/83bP4EMg+FCx6CI87TVZwiElYK9J7iHHz2N3jrv6G0ANJHw3kPwsSv6AZVItIjlCzh5hys/XuwR16yDAaOgnP/AJMuUpCLSI9SwoSLc/DFW/DW7VD8IaSNhDm/g7xLdKdCEekVCvRwWJcfHFrZ8D6kDoezfwOTLw0+DUdEpJco0A9G0bvw9u1QtDj41J3Zv4Kjvh58dJmISC9ToB+IQAu8/ANY9ggkD4Ez/geOvhxiE7yuTET6MQX6/mppgv/7DnzyF5jxPTjpVogd4HVVIiIK9P3S3AB/mQdrXoaTfwYn3Oh1RSIirRTo3dVYB09/DT5/E868C4692uuKRETaUaB3R301PHERbPwAzr0v+AxPEZE+RoHelboKePz84GX7FzwEE8/3uiIRkU4p0PelZgv8aS6Ufx68te3hZ3hdkYjIXinQ96ZqIyw4F2o2w9eegUNmeV2RiMg+KdA7U/55MMzrq+GyF2DksV5XJCLSJQV6R1sLg2He0gTfWAjZk72uSESkW/RkhbY2fQyPzAYM5i1SmItIRFGg77LhA3hsDsQlwxWLYPA4rysSEdkvCnSAz9+CP50HyYODYZ5xiNcViYjst24FupmdYWZrzGytmd3cyfrLzazMzApCr2+Gv9QesmYRPPHV4BOF5i2CtOFeVyQickC6/FLUzHzAfcCpQDGwxMwWOudWdWj6tHPu2h6osed8+hw8fzUMnQSXPgeJGV5XJCJywLrTQz8GWOuc+8I51wg8BZzbs2X1go/+BM9eCSOOha+/qDAXkYjXnUDPATa2mS8OLevoAjNbYWbPmtmIznZkZleb2VIzW1pWVnYA5YbJBw/AwmthzJfha89CQqp3tYiIhEm4vhR9Cch1zk0CXgce66yRc26+c26qc25qVlZWmD56P+X/Cl79Dxh3NlzyJMQlelOHiEiYdSfQS4C2Pe7hoWWtnHPlzrmG0OwfgaPDU14YOQdv/Ce8eRsc+VW48DE9Kk5Eokp3An0JMNbMRptZHHAxsLBtAzMb1mZ2DlAYvhLDIBCARf8B79wdfFTceQ+CTxfJikh06TLVnHPNZnYt8BrgAx52zq00s18AS51zC4HrzGwO0AxUAJf3YM37J9ACC6+Dgsdh+rVw2i/BzOuqRETCzpxznnzw1KlT3dKlS3v2Q1qagqclrnweTrwZZt2sMBeRiGZmy5xzUztbF73jDk318JfL4V+L4NTb4PjrvK5IRKRHRWegN+yAp/4N1uXDWXfDtCu9rkhEpMdFX6DvrApeyl+8BM57APIu9roiEZFeEV2BXlsefGTc1sLgaYkT5nhdkYhIr4meQK8uDYZ5ZRFc8hSMPcXrikREelV0BHrVhuC9zGvLgjfZyv2S1xWJiPS6yA/0bWthwRxo3BG8ydbwTs/mERGJepEd6FtWwoK54AJw+csw9EivKxIR8UzkPrGoZFnw+Z8xfrjiVYW5iPR7kRnoRe/CY+fCgIHBR8YNGut1RSIinou8QP/ibXj8AkgdFnxkXHqu1xWJiPQJkRfoqTkwakYwzFOzva5GRKTPiLwvRQeNhcue97oKEZE+J/J66CIi0ikFuohIlFCgi4hECQW6iEiUUKCLiEQJBbqISJRQoIuIRAkFuohIlDDnnDcfbFYGrD/AzQcB28JYTqTT8WhPx2M3HYv2ouF4jHLOZXW2wrNAPxhmttQ5pxufh+h4tKfjsZuORXvRfjw05CIiEiUU6CIiUSJSA32+1wX0MToe7el47KZj0V5UH4+IHEMXEZE9RWoPXUREOlCgi4hEiYgLdDM7w8zWmNlaM7vZ63q8ZGYjzOwtM1tlZivN7Pte1+Q1M/OZ2cdm9leva/GamQ00s2fNbLWZFZrZdK9r8oqZ3RD6HfnUzJ40swSva+oJERXoZuYD7gPOBCYAl5jZBG+r8lQz8APn3ATgOOCafn48AL4PFHpdRB9xD/Cqc24ckEc/PS5mlgNcB0x1zk0EfMDF3lbVMyIq0IFjgLXOuS+cc43AU8C5HtfkGedcqXPuo9B0DcFf2Bxvq/KOmQ0HzgL+6HUtXjOzNGAm8BCAc67ROVflaVHe8gMDzMwPJAKbPK6nR0RaoOcAG9vMF9OPA6wtM8sFpgD/9LgUL/0WuAkIeFxHXzAaKAMeCQ1B/dHMkrwuygvOuRLgV8AGoBTY7pz7m7dV9YxIC3TphJklA88B1zvnqr2uxwtmdjaw1Tm3zOta+gg/cBRwv3NuClAL9MvvnMwsneD/5EcD2UCSmV3qbVU9I9ICvQQY0WZ+eGhZv2VmsQTD/M/Ouee9rsdDxwNzzKyI4FDcl83scW9L8lQxUOyc2/U/tmcJBnx/dAqwzjlX5pxrAp4HZnhcU4+ItEBfAow1s9FmFkfwi42FHtfkGTMzgmOkhc65u72ux0vOuR8554Y753IJ/r140zkXlb2w7nDObQY2mtnhoUUnA6s8LMlLG4DjzCwx9DtzMlH6BbHf6wL2h3Ou2cyuBV4j+E31w865lR6X5aXjgcuAT8ysILTsFufcK96VJH3I94A/hzo/XwDzPK7HE865f5rZs8BHBM8M+5govQWALv0XEYkSkTbkIiIie6FAFxGJEgp0EZEooUAXEYkSCnQRkSihQBcRiRIKdBGRKPH/RXZOuHVsoR8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9744\n"
     ]
    }
   ],
   "source": [
    "# 导入必要的库\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms, models\n",
    "from tqdm import *\n",
    "import numpy as np\n",
    "import sys\n",
    "\n",
    "# 设置随机种子\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# 定义模型、优化器、损失函数\n",
    "model = LeNet()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.02)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# 设置数据变换和数据加载器\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),  # 将数据转换为张量\n",
    "    transforms.Normalize((0.5,), (0.5,))  # 对数据进行归一化\n",
    "])\n",
    "train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)  # 加载训练数据\n",
    "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)  # 实例化训练数据加载器\n",
    "test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)  # 加载测试数据\n",
    "test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)  # 实例化测试数据加载器\n",
    "\n",
    "# 设置epoch数并开始训练\n",
    "num_epochs = 10  # 设置epoch数\n",
    "loss_history = []  # 创建损失历史记录列表\n",
    "acc_history = []   # 创建准确率历史记录列表\n",
    "\n",
    "# tqdm用于显示进度条并评估任务时间开销\n",
    "for epoch in tqdm(range(num_epochs), file=sys.stdout):\n",
    "    # 记录损失和预测正确数\n",
    "    total_loss = 0\n",
    "    total_correct = 0\n",
    "    \n",
    "    # 批量训练\n",
    "    model.train()\n",
    "    for inputs, labels in train_loader:\n",
    "\n",
    "        # 预测、损失函数、反向传播\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        # 记录训练集loss\n",
    "        total_loss += loss.item()\n",
    "    \n",
    "    # 测试模型，不计算梯度\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in test_loader:\n",
    "\n",
    "            # 预测\n",
    "            outputs = model(inputs)\n",
    "            # 记录测试集预测正确数\n",
    "            total_correct += (outputs.argmax(1) == labels).sum().item()\n",
    "        \n",
    "    # 记录训练集损失和测试集准确率\n",
    "    loss_history.append(np.log10(total_loss))  # 将损失加入损失历史记录列表，由于数值有时较大，这里取对数\n",
    "    acc_history.append(total_correct / len(test_dataset))# 将准确率加入准确率历史记录列表\n",
    "    \n",
    "    # 打印中间值\n",
    "    if epoch % 2 == 0:\n",
    "        tqdm.write(\"Epoch: {0} Loss: {1} Acc: {2}\".format(epoch, loss_history[-1], acc_history[-1]))\n",
    "\n",
    "# 使用Matplotlib绘制损失和准确率的曲线图\n",
    "import matplotlib.pyplot as plt\n",
    "plt.plot(loss_history, label='loss')\n",
    "plt.plot(acc_history, label='accuracy')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# 输出准确率\n",
    "print(\"Accuracy:\", acc_history[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d366668d",
   "metadata": {},
   "source": [
    "这段代码使用了MNIST数据集，它包含手写数字的图像和对应的标签。它使用了LeNet模型，并使用了SGD优化器和交叉熵损失函数。每个epoch结束后，它会在测试集上评估模型，并输出测试集上的损失和准确率。如果你想使用其他数据集，你可以替换MNIST数据集，并在train_dataset和test_dataset中使用你想要使用的数据集。你可以在PyTorch的torchvision包中找到许多常见数据集，或者你也可以从外部加载数据。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df972407",
   "metadata": {},
   "source": [
    "### 7.5.3. 小结\n",
    "- 卷积神经网络（CNN）是一种深度学习模型，常用于图像分类和特征提取。\n",
    "- CNN由卷积层、汇聚层、非线性激活函数和全连接层组成，通过卷积计算、降维和学习非线性关系来提取输入数据的特征。\n",
    "- 下面是几种提高卷积神经网络性能的方法：\n",
    "    1. 堆叠更多的卷积层\n",
    "    2. 使用较小的卷积核\n",
    "    3. 增加卷积层的通道数\n",
    "    4. 调整学习率\n",
    "    5. 调整正则化参数\n",
    "    6. 调整批量大小\n",
    "    7. 使用数据增强\n",
    "- LeNet是最早发布的卷积神经网络之一，虽然发布已经有些时日，但它仍然是一个很好的入门模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af348896",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
